# Ryan Copus, Ryan Hübert and Paige Pellaton
# "Trading Diversity? Judicial Diversity and Case Outcomes in Federal Courts"
# American Political Science Review

# File name: chp_apsr_05_outcomes.py
# Last revision date: March 24, 2024
# Questions or comments? Contact Ryan Hübert: https://ryanhubert.com/

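# What does this script do?
# This script tests whether miscoded case outcomes (cases whose OUTCOME_IDB
# differs from OUTCOME_RECODED) can be predicted from pretreatment case
# characteristics alone ("benchmark" models) and together with judge
# characteristics ("saturated" models), using cross-validated logistic,
# LASSO, random forest and ensemble learners.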

# Last pre-print execution of this code:
# > Date: March 24, 2024
# > Machine: MacBook Pro 14" (2021 model) with Apple M1 Max chip and 64 GB RAM
# > OS: macOS Sonoma 14.4
# > Python: version 3.10

################################################################################
# Load packages and set options
################################################################################

import pandas as pd
import os
import h2o
from h2o.estimators.random_forest import H2ORandomForestEstimator
from h2o.estimators.glm import H2OGeneralizedLinearEstimator
import time
import re
import polars, pyarrow

################################################################################
# Directory management
################################################################################

# Define working directory: strip a trailing "/code" (or "/Code") component so
# the script works whether it is launched from the replication root or from
# its code subfolder. The pattern is anchored to the end of the path so that
# other path components that merely contain "code" (e.g. "/Users/me/Code/..."
# or ".../codebase/...") are left intact.
wdir = re.sub(r"/[Cc]ode/?$", "", os.getcwd())

# Use the latest output run: the sub-folder of outs/ whose name starts with
# "202" and sorts last lexicographically (presumably date-stamped runs).
outs_dir = wdir + "/outs"
runs = [x for x in os.listdir(outs_dir) if x.startswith("202")]
if not runs:
    raise FileNotFoundError("No output folders starting with '202' found in " + outs_dir)
cdir = outs_dir + "/" + max(runs)

# Folder for this script's results (created if missing; no exists/mkdir race).
rdir = cdir + "/randomization"
os.makedirs(rdir, exist_ok=True)

# Memory (in GB) for the H2O cluster. Surrounding whitespace and a trailing
# "G"/"g" typed by the user are dropped, because the "G" unit is appended when
# the cluster is started.
mmem = input("How much memory do you want to use? ").strip().rstrip("Gg")

################################################################################
## Import and clean the data
################################################################################

dfile = wdir + "/data/chp_apsr_case_data.csv"

# Import case and plaintiff identity data and then merge them together.
# Every case column is read as text (dtype=object). The plaintiff file's
# OPEN_ID is also read as text so the merge key has the same type on both
# sides; otherwise pandas may infer integers for purely numeric ids.
df = pd.read_csv(dfile, dtype=object)
pla_cols = ["OPEN_ID", "pla_female", "pla_nonwhite_w"]
plaintiffs = pd.read_csv(wdir + "/outs/chp_apsr_plaintiff_identities.csv",
                         usecols=pla_cols, dtype={"OPEN_ID": str})[pla_cols]
df = df.merge(plaintiffs, on="OPEN_ID", how="left")
del plaintiffs

# Flag (1/0) cases whose IDB outcome differs from the recoded outcome.
# NOTE(review): `!=` treats NaN as unequal to everything, including another
# NaN, so a case missing BOTH outcomes is flagged as miscoded. Confirm this is
# intended, or that neither column has missing values.
df["miscoded"] = (df["OUTCOME_IDB"] != df["OUTCOME_RECODED"]).astype("int64")

# Define the subsets of cases for which we will run randomization checks.
# Keep the union of the case ids stored in masks/<Model_ID>.csv for every
# model in model_index.csv with:
#   - Treatment "nontraditional" or "republican";
#   - Outcome "settlement";
#   - empty Plaintiff_Shares and Race_Coding entries;
#   - President_Controls == 0 and SCALES == 0.
model_index = pd.read_csv(cdir + "/model_index.csv")
is_check_model = (
    model_index["Treatment"].isin(["nontraditional", "republican"])
    & (model_index["Outcome"] == "settlement")
    & model_index["Plaintiff_Shares"].isnull()
    & model_index["Race_Coding"].isnull()
    & (model_index["President_Controls"] == 0)
    & (model_index["SCALES"] == 0)
)

keep_ids = set()
for model_id in model_index.loc[is_check_model, "Model_ID"]:
    mask_file = cdir + "/masks/" + str(model_id).zfill(3) + ".csv"
    # Read ids as text so they compare equal to df's text OPEN_ID column.
    keep_ids.update(pd.read_csv(mask_file, dtype={"OPEN_ID": str})["OPEN_ID"].dropna())

# Remove rows from df that we won't be using. Starting from an explicit id set
# (instead of seeding the mask with OPEN_ID.isnull()) ensures cases with a
# missing OPEN_ID are never retained. The copy makes later column
# assignments act on an independent frame rather than a view.
df = df.loc[df["OPEN_ID"].isin(keep_ids), :].copy()

# Define sets of predictors. The list order is kept as-is because it sets the
# order in which predictors are passed to the learners.

# Case-level pretreatment predictors that are converted to H2O factors before
# training.
pvars1 = [
    "block",
    "NOS",
    "SECTION",
    "ORIGIN",
    "JURIS",
    "JURY",
    "PROSE",
    "COUNTY_NAME",
]

# Remaining case-level pretreatment predictors (no explicit factor
# conversion). The last two come from the plaintiff identities file.
pvars2 = [
    'def_count',
    'oth_count',
    'pla_count',
    'def_attorney_count',
    'oth_attorney_count',
    'pla_attorney_count',
    'def_prose_count',
    'oth_prose_count',
    'pla_prose_count',
    'def_anonymous',
    'oth_anonymous',
    'pla_anonymous',
    'def_repeat_party_count',
    'oth_repeat_party_count',
    'pla_repeat_party_count',
    'pla_female',
    'pla_nonwhite_w',
]

# Judge variables: only used as predictors by the "saturated" specification.
# jvars1 are converted to H2O factors; jvars2 are passed through as-is.
jvars1 = [
    'jid',
    'president',
    'aba_rating',
]
jvars2 = [
    'chief',
    'senior',
    'republican',
    'woman',
    'black',
    'latino',
    'asian',
    'white',
]

# Keep the ids of judges with more than 1,000 cases in the analysis sample.
# All other judges, plus any case with a missing jid (value_counts() never
# counts missing values), are pooled under the placeholder id "9999999".
jid_counts = df["jid"].value_counts()
frequent_judges = jid_counts[jid_counts > 1000].index
df.loc[~df["jid"].isin(frequent_judges), "jid"] = "9999999"

# Reduce the memory impact of the dataframe.
# dict.fromkeys drops duplicate names while keeping first-seen order. The
# previous set() gave a column order that could change from run to run,
# because Python randomizes string hashing per process. That makes the
# downstream H2O frames, and any column-order-dependent modeling step, hard
# to reproduce.
keep_cols = ["OPEN_ID", "miscoded"] + list(dict.fromkeys(jvars1 + jvars2 + pvars1 + pvars2))
df = df.loc[:, keep_cols]

################################################################################
## Set up for the machine learning algorithms that will be run
################################################################################

# Define the learners
# Cross-validation settings shared by every learner: 10 folds, seed 2020, and
# retention of fold models, fold assignments and out-of-fold (holdout)
# predictions. The holdout predictions are read for every learner and are the
# training data of the ensemble.
_shared_cv = dict(nfolds=10,
                  keep_cross_validation_predictions=True,
                  keep_cross_validation_models=True,
                  keep_cross_validation_fold_assignment=True,
                  seed=2020)

# Insertion order matters: the training loop fits the learners in this order
# (skipping "my_ensemble"), and that fixes the column order of the ensemble's
# input frame.
base_learners = {
    # Unpenalized (lambda_=0) logistic regression on unstandardized inputs.
    # NOTE(review): unlike my_lasso and my_rf, no fold_assignment="Modulo" is
    # set here, so H2O's default fold assignment is used. Confirm the mismatch
    # is intended.
    "my_ols": H2OGeneralizedLinearEstimator(family="binomial",
                                            lambda_=0,
                                            standardize=False,
                                            **_shared_cv),
    # L1-penalized (alpha=1, LASSO) logistic regression.
    # NOTE(review): H2O documents nlambdas as applying only when
    # lambda_search=True, which is not set here. Confirm the intended penalty.
    "my_lasso": H2OGeneralizedLinearEstimator(family="binomial",
                                              alpha=1,
                                              nlambdas=100,
                                              fold_assignment="Modulo",
                                              **_shared_cv),
    # Random forest with 500 trees.
    "my_rf": H2ORandomForestEstimator(ntrees=500,
                                      fold_assignment="Modulo",
                                      **_shared_cv),
    # Stacking layer: unpenalized logistic regression with non-negative
    # coefficients, trained on the base learners' holdout predictions.
    "my_ensemble": H2OGeneralizedLinearEstimator(family="binomial",
                                                 lambda_=0,
                                                 standardize=False,
                                                 non_negative=True,
                                                 **_shared_cv),
}

# Define a function that extracts what we need from the models
def GetStuff(model, vimp_frame, perf_frame, roc_frame, preds_frame, tf = None):
    """Append one fitted model's diagnostics to the running result tables.

    Rows are labelled with the module-level globals set by the main loop:
    ``y`` (outcome name), ``m`` (specification label) and ``alg`` (learner
    key). The global ``dstats`` (a pandas frame with the outcome and ``jid``
    columns) supplies the observation and judge counts.

    Parameters
    ----------
    model : fitted H2O estimator
        Must have been trained with cross-validation and
        ``keep_cross_validation_predictions=True``.
    vimp_frame, perf_frame, roc_frame, preds_frame : pandas.DataFrame
        Accumulators; the new rows are concatenated and the results returned.
    tf : H2OFrame
        The frame the model was trained on. Its ``OPEN_ID`` column is
        aligned positionally with the model's holdout predictions.

    Returns
    -------
    list
        ``[vimp_frame, perf_frame, roc_frame, preds_frame, model_preds]``.
        ``model_preds`` holds only this model's CV predictions.

    Raises
    ------
    ValueError
        If ``tf`` is not supplied.
    """
    # Fail fast: the CV predictions cannot be labelled with OPEN_ID without tf.
    if tf is None:
        raise ValueError("GetStuff() requires `tf`, the frame the model was trained on")

    # Variable importance
    vimp = pd.DataFrame(model.varimp(), columns=['var', 'rel_imp', 'scaled_imp', 'percentage'])
    vimp["outvar"] = y
    vimp["model"] = m
    vimp["algorithm"] = alg
    vimp_frame = pd.concat([vimp_frame, vimp], axis=0)

    # Cross-validated performance. The metrics object is fetched once and
    # reused for the ROC curve below.
    # NOTE(review): y is an H2O factor in the training frame. Depending on how
    # as_data_frame() returns it, dstats[y] may hold "0"/"1" strings rather
    # than ints. Confirm the == 1 / == 0 judge counts below are non-zero.
    xval_perf = model.model_performance(xval=True)
    perf_row = pd.DataFrame(
        [[y, round(model.mse(xval=True), 6), round(model.auc(xval=True), 6),
          m, alg, len(dstats),
          dstats.loc[dstats[y] == 1, "jid"].nunique(),   # judges with y == 1
          dstats.loc[dstats[y] == 0, "jid"].nunique()]],  # judges with y == 0
        columns=["outvar", "mse", "auc", "model", "algorithm", "obs", "tr_judges", "co_judges"])
    perf_frame = pd.concat([perf_frame, perf_row], axis=0)

    # Cross-validated ROC curve
    roc = pd.DataFrame({"fpr": xval_perf.fprs,
                        "tpr": xval_perf.tprs,
                        "outvar": y, "model": m, "algorithm": alg})
    roc_frame = pd.concat([roc_frame, roc], axis=0)

    # Cross-validation holdout predictions (probability of class 1), paired
    # row-by-row with the case ids of the training frame.
    with h2o.utils.threading.local_context(polars_enabled=True, datatable_enabled=True):
        model_preds = pd.concat([tf["OPEN_ID"].as_data_frame(),
                                 model.cross_validation_holdout_predictions().as_data_frame()["p1"]], axis=1)

    model_preds = model_preds.rename(columns={"p1": "prediction"})
    model_preds["outvar"] = y
    model_preds["model"] = m
    model_preds["algorithm"] = alg
    model_preds["pred_type"] = "cv"
    preds_frame = pd.concat([preds_frame, model_preds], axis=0)

    return [vimp_frame, perf_frame, roc_frame, preds_frame, model_preds]

# Consolidate categorical variables that have low counts to improve model performance
def Preprocess(datafr, cat_vars=None, min_count=20):
    """Pool rarely occurring levels of categorical predictors into one level.

    Each row's ``division`` is the leading run of lowercase letters followed
    by digits in its ``block`` value. For every variable in ``cat_vars``
    other than ``block``, any level that occurs ``min_count`` times or fewer
    within a division is replaced by ``"OTHER_LOW_FREQUENCY"`` in that
    division. This consolidation is meant to improve model performance.

    Parameters
    ----------
    datafr : pandas.DataFrame (or anything ``pd.DataFrame`` accepts)
        Must contain ``block`` and every column named in ``cat_vars``.
        It is copied and never modified in place.
    cat_vars : list of str, optional
        Categorical columns to consolidate. Defaults to the module-level
        ``pvars1`` list. ``block`` itself is always skipped.
    min_count : int, default 20
        Levels with at most this many rows in a division are pooled.

    Returns
    -------
    pandas.DataFrame
        Same columns as the input.

        NOTE: rows come back re-ordered and re-indexed. After each variable,
        the frame is sorted by that variable's low-frequency count, so rows
        with a pooled level come first. This order is deterministic. H2O's
        "Modulo" fold assignment (used by some learners) depends on row order,
        so it is kept as-is to preserve the original results.
    """
    if cat_vars is None:
        cat_vars = pvars1

    # Work on a copy so the caller's frame is never modified.
    out = pd.DataFrame(datafr).copy()
    out["division"] = out["block"].str.extract(r"^([a-z]+[0-9]+)", expand=False)

    for var in (v for v in cat_vars if v != "block"):
        # Rows per (division, level); keep only the rare combinations.
        counts = out.groupby("division")[var].value_counts().reset_index(name="n")
        rare = counts.loc[counts["n"] <= min_count]
        # The left merge gives rows with a rare (division, level) a non-null n.
        # sort_values("n") reproduces the original row ordering (see Returns).
        out = out.merge(rare, how="left", on=["division", var]).sort_values("n")
        out.loc[out["n"].notna(), var] = "OTHER_LOW_FREQUENCY"
        out = out.drop(columns="n")

    return out.drop(columns="division")

################################################################################
## Run the models
################################################################################

# Define the response variable: whether a case's IDB outcome differs from its
# recoded outcome (the "miscoded" flag built during data import).
y = "miscoded"

fn = y + ".csv"

# Launch the H2O cluster with the requested memory (in GB)
h2o.init(max_mem_size=str(mmem) + "G")
h2o.no_progress()

print("\n\n" + "=" * 80 + "\nNow running models on all cases\n" + "=" * 80 + "\n\n")

# Result accumulators filled by GetStuff()
roc = pd.DataFrame()
perf = pd.DataFrame()
preds = pd.DataFrame()
vimps = pd.DataFrame()

# Preprocess() depends only on df (and the fixed pvars1 list) and is
# deterministic, so it is computed once rather than once per specification.
prepped = Preprocess(df)

# "benchmark" uses only the case-level predictors; "saturated" adds the judge
# variables.
for m in ["benchmark", "saturated"]:

    print("\n" + "     >> Running " + m + " model ")

    # Create the H2O training frame; the outcome and the categorical
    # predictors become factors.
    tf = h2o.H2OFrame(prepped)
    for v in [y] + pvars1 + jvars1:
        tf[v] = tf[v].asfactor()

    # Local copy of the outcome/judge/block columns. GetStuff() reads it
    # through the global name dstats.
    with h2o.utils.threading.local_context(polars_enabled=True, datatable_enabled=True):
        dstats = tf[:, [y, 'jid', 'block']].as_data_frame()

    pvars = pvars1 + pvars2
    if m != "benchmark":
        pvars = pvars1 + pvars2 + jvars1 + jvars2

    # eft collects each base learner's holdout predictions as a column named
    # after the learner; it is the ensemble's training data.
    with h2o.utils.threading.local_context(polars_enabled=True, datatable_enabled=True):
        eft = pd.DataFrame(tf[:, ["OPEN_ID", y]].as_data_frame())

    for alg in base_learners:
        if "my_ens" in alg:
            continue
        h2o.show_progress()
        base_learners[alg].train(x=pvars, y=y, training_frame=tf)
        h2o.no_progress()
        vimps, perf, roc, preds, eft1 = GetStuff(base_learners[alg], vimps, perf, roc, preds, tf)
        eft = eft.merge(eft1.loc[:, ["OPEN_ID", "prediction"]].rename(columns={"prediction": alg}),
                        on="OPEN_ID")
        del eft1

    ## Estimate the ensemble. GetStuff() labels results with the global `alg`,
    ## so set it explicitly instead of relying on the loop above having ended
    ## on "my_ensemble".
    alg = "my_ensemble"
    eft = h2o.H2OFrame(eft)
    eft[y] = eft[y].asfactor()
    h2o.show_progress()
    base_learners[alg].train(x=[x for x in eft.columns if 'my_' in x], y=y, training_frame=eft)
    h2o.no_progress()
    vimps, perf, roc, preds, eft1 = GetStuff(base_learners[alg], vimps, perf, roc, preds, eft)

    del dstats

# Write each accumulated table to <rdir>/<name>_<y>.csv. The tables are looked
# up in a dict rather than via eval().
outputs = {"roc": roc, "perf": perf, "preds": preds, "vimps": vimps}
for f, frame in outputs.items():
    frame.to_csv(rdir + "/" + f + "_" + fn, mode="w", index=False, header=True)

h2o.remove_all()
h2o.cluster().shutdown()
del roc, perf, preds, vimps, outputs, prepped
time.sleep(5)